from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

class YOLODetector:
    def __init__(self, model_path='models/best.engine'):
        # 加载 TensorRT 模型
        self.model = YOLO(model_path, task="detect")
        # 英文类别名称到中文的映射
        self.class_name_mapping = {
            'pedestrian': '行人',
            'people': '人群',
            'bicycle': '自行车',
            'car': '轿车',
            'van': '面包车',
            'truck': '卡车',
            'tricycle': '三轮车',
            'awning-tricycle': '篷式三轮车',
            'bus': '公交车',
            'motor': '摩托车'
        }
        # 为每个类别设置固定的RGB颜色
        self.color_mapping = {
            'pedestrian': (71, 0, 36),      # 勃艮第红
            'people': (0, 255, 0),          # 绿色
            'bicycle': (0, 49, 83),         # 普鲁士蓝
            'car': (0, 47, 167),            # 克莱茵蓝
            'van': (128, 0, 128),           # 紫色
            'truck': (212, 72, 72),         # 缇香红
            'tricycle': (0, 49, 83),        # 橙色
            'awning-tricycle': (251, 220, 106), # 申布伦黄
            'bus': (73, 45, 34),            # 凡戴克棕
            'motor': (1, 132, 127)          # 马尔斯绿
        }
        # 初始化类别计数器
        self.class_counts = {cls_name: 0 for cls_name in self.class_name_mapping.keys()}
        # 初始化字体
        try:
            self.font = ImageFont.truetype("simhei.ttf", 20)
        except IOError:
            self.font = ImageFont.load_default()

    def detect_and_draw_English(self, frame, conf=0.3, iou=0.5):
        """
        对输入帧进行目标检测并返回绘制结果
        
        Args:
            frame: 输入的图像帧（BGR格式）
            conf: 置信度阈值
            iou: IOU阈值
        
        Returns:
            annotated_frame: 绘制了检测结果的图像帧
        """
        try:
            # 进行 YOLO 目标检测
            results = self.model(
                frame,
                conf=conf,
                iou=iou,
                half=True,
            )
            result = results[0]
            
            # 使用YOLO自带的绘制功能
            annotated_frame = result.plot()
            
            return annotated_frame
            
        except Exception as e:
            print(f"Detection error: {e}")
            return frame

    def detect_and_draw_Chinese(self, frame, conf=0.2, iou=0.3):
        """
        对输入帧进行目标检测并绘制中文标注
        
        Args:
            frame: 输入的图像帧（BGR格式）
            conf: 置信度阈值
            iou: IOU阈值
        
        Returns:
            annotated_frame: 绘制了检测结果的图像帧
        """
        try:
            # 进行 YOLO 目标检测
            results = self.model(
                frame,
                conf=conf,
                iou=iou,
                # half=True,
            )
            result = results[0]
            
            # 获取原始帧的副本
            img = frame.copy()
            
            # 转换为PIL图像以绘制中文
            pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            draw = ImageDraw.Draw(pil_img)
            
            # 绘制检测结果
            for box in result.boxes:
                # 获取边框坐标
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
                
                # 获取类别ID和置信度
                cls_id = int(box.cls[0].item())
                conf = box.conf[0].item()
                
                # 获取类别名称并转换为中文
                cls_name = result.names[cls_id]
                chinese_name = self.class_name_mapping.get(cls_name, cls_name)
                
                # 获取该类别的颜色
                color = self.color_mapping.get(cls_name, (255, 255, 255))
                
                # 绘制边框
                draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)
                
                # 准备标签文本
                text = f"{chinese_name} {conf:.2f}"
                text_size = draw.textbbox((0, 0), text, font=self.font)
                text_width = text_size[2] - text_size[0]
                text_height = text_size[3] - text_size[1]
                
                # 绘制标签背景（使用与边框相同的颜色）
                draw.rectangle(
                    [(x1, y1 - text_height - 4), (x1 + text_width, y1)],
                    fill=color
                )
                
                # 绘制白色文本
                draw.text(
                    (x1, y1 - text_height - 2),
                    text,
                    fill=(255, 255, 255),  # 白色文本
                    font=self.font
                )
            
            # 转换回OpenCV格式
            return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
            
        except Exception as e:
            print(f"Detection error: {e}")
            return frame
